import os
import sys
sys.path.append('./')
import glob
import argparse

import json
import joblib
import torch
import numpy as np
from smplx import SMPL, SMPLX
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
import config as _C


# A body segment is marked "in contact" when the mean of its per-vertex
# contact predictions exceeds this value.
DISCRETE_CONTACT_THR = 0.3

# Body-model assets live under the current user's home data folder.
_BODY_MODELS_ROOT = f"/home/{os.getenv('USER')}/Data/body_models"
SMPL_MODEL_DIR = f"{_BODY_MODELS_ROOT}/smpl"
SMPLX_MODEL_DIR = f"{_BODY_MODELS_ROOT}/smplx"
SMPLX2SMPL_PTH = f"{_BODY_MODELS_ROOT}/smplx2smpl.pkl"

# JSON mapping of part names to vertex-index lists (path is relative to the
# current working directory).
VERTS_SEGMENT_IDXS = "data/smplx_part_idxs_bcd.json"

# Sequences evaluated by this script.
test_sequences = "P01_Seq01 P02_Seq02 P02_Seq03 P03_Seq03 P05_Seq01".split()

SEGMENT_NAME_LIST = [_C.SENSOR_NAME_MAPPER[k] for k in _C.SENSOR_NAME_LIST]

# Columns of the per-segment contact arrays (ordered like SEGMENT_NAME_LIST)
# that are reported as the "hand" and "foot" subsets.
HAND_SEGMENT_IDXS = [22, 23]
FOOT_SEGMENT_IDXS = [8, 9, 10, 11]

# Report groups: (suffix used in the per-camera ``results`` keys,
# label used in the final printout, segment columns or None for all columns).
_METRIC_GROUPS = (
    ("full", "all", None),
    ("hand", "hand", HAND_SEGMENT_IDXS),
    ("foot", "foot", FOOT_SEGMENT_IDXS),
)


def _precision_recall_f1(labels, preds, columns=None):
    """Return ``(precision, recall, f1)`` of binary ``preds`` against ``labels``.

    Args:
        labels: ground-truth contact array; rows index evaluated frames and
            columns index body segments.
        preds: predicted contact array with the same shape as ``labels``.
        columns: optional list of segment columns to restrict scoring to;
            ``None`` scores every column.

    The selected entries are flattened, so each (frame, segment) pair counts
    as one binary sample.
    """
    if columns is not None:
        labels = labels[:, columns]
        preds = preds[:, columns]
    y_true = labels.reshape(-1)
    y_pred = preds.reshape(-1)
    return (
        precision_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        f1_score(y_true, y_pred),
    )


def _smpl_part_segmentation(smplx_part_segm, smplx2smpl, overlap_thr=0.6):
    """Transfer a part segmentation through the ``smplx2smpl`` mapping matrix.

    For each part, a row of ``smplx2smpl`` is assigned to that part when its
    weights summed over the part's column indices exceed ``overlap_thr``.
    Presumably rows are SMPL vertices and columns are SMPL-X vertices (confirm
    against the mapping file).
    """
    return {
        name: np.where(smplx2smpl[:, idxs].sum(-1) > overlap_thr)[0]
        for name, idxs in smplx_part_segm.items()
    }


def _discretize_prediction(dense_pred, smpl_part_segm, template):
    """Collapse dense per-vertex contact into per-segment binary contact.

    Args:
        dense_pred: contact predictions indexed as (row, vertex).
        smpl_part_segm: mapping from segment name to vertex indices.
        template: array whose shape and dtype the output copies (the
            ground-truth rows being evaluated).

    A segment's column is set when the mean of its vertices' predictions is
    above ``DISCRETE_CONTACT_THR``. Columns with no entry in
    ``smpl_part_segm`` stay zero.
    """
    discrete_pred = np.zeros_like(template)
    for seg_name, seg_idxs in smpl_part_segm.items():
        col = SEGMENT_NAME_LIST.index(seg_name)
        # Fancy indexing already copies, so no explicit .copy() is needed.
        contact_ratio = dense_pred[:, seg_idxs].astype(float).mean(-1)
        # NOTE(review): assumes dense_pred has exactly one row per evaluated
        # (valid) frame, in the same order as ``template`` — confirm.
        discrete_pred[:, col] = contact_ratio > DISCRETE_CONTACT_THR
    return discrete_pred


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): --model is parsed but not used yet; presumably it is meant
    # to select the prediction path below.
    parser.add_argument('-m', '--model', default='deco')
    args = parser.parse_args()

    with open(VERTS_SEGMENT_IDXS, "rb") as f:
        smplx_part_segm = json.load(f)
    # NOTE(review): this module also defines SMPLX2SMPL_PTH, but the path is
    # read from config here — confirm which one is intended.
    smplx2smpl = torch.from_numpy(
        joblib.load(_C.SMPLX2SMPL_PTH)["matrix"]
    ).float()

    # Make smpl part segmentation
    smpl_part_segm = _smpl_part_segmentation(smplx_part_segm, smplx2smpl)

    # Per-camera scores keyed "<metric>_<group>" (collected, not printed).
    results = {
        f"{metric}_{group}": []
        for group, _, _ in _METRIC_GROUPS
        for metric in ("precision", "recall", "f1")
    }
    all_preds, all_labels = [], []

    for sequence in tqdm(test_sequences):
        annotation_pth = os.path.join(_C.DATA_BASE_DIR, sequence, f"{sequence}_annot.pkl")
        annotation = joblib.load(annotation_pth)
        groundtruth = annotation["contact"].copy()

        for camera_name in _C.EXO_CAMERA_NAMES:
            # Trailing digit of the camera name, converted to a 0-based index.
            camera_i = int(camera_name[-1]) - 1
            if camera_i not in annotation["camera_idxs"]:
                continue

            # Rows whose bbox last entry is -1 are excluded from evaluation.
            # NOTE(review): bboxes are not indexed by camera_i — confirm they
            # are shared across cameras.
            bboxes = annotation["bboxes"]
            valid_idxs = np.where(bboxes[:, -1] != -1)[0]

            # TODO: point this at the prediction for (sequence, camera_name);
            # the placeholder does not vary with either loop variable.
            prediction_pth = "path-to-your-prediction-path"
            pred = joblib.load(prediction_pth)

            dense_pred = pred["contact"].copy()
            if dense_pred.ndim == 3:
                dense_pred = dense_pred[..., 0]

            discrete_gt = groundtruth[valid_idxs]
            discrete_pred = _discretize_prediction(dense_pred, smpl_part_segm, discrete_gt)

            all_preds.append(discrete_pred)
            all_labels.append(discrete_gt)

            for group, _, columns in _METRIC_GROUPS:
                precision, recall, f1 = _precision_recall_f1(discrete_gt, discrete_pred, columns)
                results[f"precision_{group}"].append(precision)
                results[f"recall_{group}"].append(recall)
                results[f"f1_{group}"].append(f1)

    if not all_preds:
        sys.exit("No camera/sequence pairs were evaluated; nothing to report.")

    # Pool every evaluated row across sequences and cameras, then score once.
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    for _, _type, columns in _METRIC_GROUPS:
        precision, recall, f1 = _precision_recall_f1(all_labels, all_preds, columns)
        print(f"Type: {_type}  |  Precision {precision:.4f}   Recall: {recall:.4f}   F1: {f1:.4f}")